import pennylane as qml
from pennylane import numpy as np
import numpy.linalg as la
from datetime import datetime

def f_n(weights, ansatz=None):
    """Return the scalar cost: the sum of everything ``ansatz(weights)`` yields.

    Args:
        weights: Parameters forwarded unchanged to ``ansatz``.
        ansatz: Callable evaluated at ``weights``. It has to be supplied by
            the caller; the ``None`` default is not callable, so calling the
            function without it raises ``TypeError``.

    Returns:
        ``np.sum`` over the output of ``ansatz(weights)``.
    """
    outputs = ansatz(weights)
    total = np.sum(outputs)
    return total

def apfa_schedule_by_grad(it, grad_norm, freeze_factor_init, activate_factor_init, freeze_count_th_init, activate_count_th_init):
    """Adapt the APFA freeze/activate settings to the current gradient norm.

    The schedule is driven by ``ratio = grad_norm[it] / grad_norm[0]``:

    * ``it == 0``: the initial settings are returned unchanged.
    * ``ratio < 0.2``: both factors are tripled, the freeze count threshold
      is raised by 2 and the activate count threshold by 1.
    * otherwise: both factors are multiplied by a linear ``scale`` that is
      3.0 at ``ratio == 0.2`` and 1.0 at ``ratio == 1.0``; the count
      thresholds keep their initial values. For ``ratio > 1`` the same line
      is extrapolated below 1.0 and clamped at 0.0 (reached at
      ``ratio == 1.4``) so the factors never become negative.

    Args:
        it: Index of the current iteration.
        grad_norm: Indexable sequence of gradient norms; entries ``0`` and
            ``it`` are read.
        freeze_factor_init: Base freeze factor.
        activate_factor_init: Base activate factor.
        freeze_count_th_init: Base freeze count threshold.
        activate_count_th_init: Base activate count threshold.

    Returns:
        Tuple ``(freeze_factor, activate_factor, freeze_count_th,
        activate_count_th)`` for this iteration.
    """
    if it == 0:
        return freeze_factor_init, activate_factor_init, freeze_count_th_init, activate_count_th_init

    current_gn = grad_norm[it]
    # Floor the reference norm so a vanishing initial gradient cannot cause a
    # division by zero.
    init_gn = grad_norm[0] if grad_norm[0] > 1e-12 else 1e-12
    ratio = current_gn / init_gn

    if ratio < 0.2:
        return (
            freeze_factor_init * 3.0,
            activate_factor_init * 3.0,
            freeze_count_th_init + 2,
            activate_count_th_init + 1,
        )

    scale = 1.0 + (1.0 - ratio) / 0.8 * 2.0
    # When the gradient norm grows beyond 1.4x its initial value the line
    # above turns negative. A negative freeze factor is meaningless and makes
    # the activation threshold (a product of both factors) grow again, so
    # clamp at zero. Written as a comparison so a NaN ratio still propagates.
    if scale < 0.0:
        scale = 0.0
    return (
        freeze_factor_init * scale,
        activate_factor_init * scale,
        freeze_count_th_init,
        activate_count_th_init,
    )

def gd_optimizer_apfa(ansatz, weights, noise_gamma, lr, iteration, n_check, alpha=0.7, freeze_factor=0.2, activate_factor=0.4, freeze_count_th=3, activate_count_th=2, warmup_steps=20):
    """Run noisy gradient descent with adaptive parameter freezing (APFA).

    Every iteration evaluates ``f_n`` and its gradient, updates a
    per-parameter exponential moving average (EMA) of ``|gradient|`` and asks
    ``apfa_schedule_by_grad`` for the current factors and count thresholds.
    From iteration ``warmup_steps`` onwards an active parameter is frozen
    after its EMA has stayed below the freeze threshold for enough
    consecutive iterations, and a frozen parameter is re-activated after its
    EMA has stayed above the activate threshold for enough consecutive
    iterations. Frozen parameters receive a zero update.

    Args:
        ansatz: Callable forwarded to ``f_n``.
        weights: Trainable parameters. Per-parameter state is sized with
            ``len(weights)`` and indexed by a single integer, so a
            one-dimensional array is expected; ``requires_grad`` is set on
            it each iteration.
        noise_gamma: Standard deviation of the zero-mean Gaussian noise added
            to the gradient.
        lr: Learning rate.
        iteration: Number of optimisation steps.
        n_check: Print a progress line every ``n_check`` iterations. A falsy
            value (``0`` or ``None``) disables the progress output.
        alpha: EMA decay applied to the previous average.
        freeze_factor: Base freeze factor handed to the schedule.
        activate_factor: Base activate factor handed to the schedule.
        freeze_count_th: Base number of consecutive low-EMA iterations
            required to freeze a parameter.
        activate_count_th: Base number of consecutive high-EMA iterations
            required to re-activate a parameter.
        warmup_steps: Number of initial iterations during which the freeze
            mask is left untouched.

    Returns:
        Tuple ``(loss, grad_norm, weights, freeze_mask)``: the per-iteration
        loss and gradient-norm histories, the final weights and the final
        boolean freeze mask.
    """
    grad_func = qml.grad(f_n)
    n_params = len(weights)
    freeze_mask = np.array([False]*n_params)
    freeze_count = np.zeros(n_params, dtype=int)
    activate_count = np.zeros(n_params, dtype=int)
    grad_norm = np.zeros(iteration)
    loss = np.zeros(iteration)
    grad_ema = np.zeros(n_params)
    t1 = datetime.now()
    for it in range(iteration):
        weights.requires_grad = True
        loss[it] = f_n(weights, ansatz=ansatz)
        gradient_now = grad_func(weights, ansatz=ansatz)
        grad_norm[it] = la.norm(gradient_now)

        # Element-wise EMA of the gradient magnitude for every parameter.
        grad_ema = alpha * grad_ema + (1 - alpha) * np.abs(gradient_now)

        ff, af, fct, act = apfa_schedule_by_grad(it, grad_norm, freeze_factor, activate_factor, freeze_count_th, activate_count_th)
        global_grad_ema = np.mean(grad_ema)
        freeze_threshold = ff * global_grad_ema
        # The activate threshold is expressed relative to the freeze threshold.
        activate_threshold = af * freeze_threshold

        if it >= warmup_steps:
            for i in range(n_params):
                if freeze_mask[i]:
                    # Frozen: count consecutive iterations above the
                    # activate threshold; any miss resets the streak.
                    if grad_ema[i] > activate_threshold:
                        activate_count[i] += 1
                    else:
                        activate_count[i] = 0
                    if activate_count[i] >= act:
                        freeze_mask[i] = False
                        activate_count[i] = 0
                else:
                    # Active: count consecutive iterations below the freeze
                    # threshold; any miss resets the streak.
                    if grad_ema[i] < freeze_threshold:
                        freeze_count[i] += 1
                    else:
                        freeze_count[i] = 0
                    if freeze_count[i] >= fct:
                        freeze_mask[i] = True
                        freeze_count[i] = 0

        noise = np.random.normal(0, noise_gamma, n_params)
        effective_grad = gradient_now + noise
        # Zero after adding noise so frozen parameters do not move at all.
        effective_grad[freeze_mask] = 0.0
        weights = weights - lr * effective_grad

        # A falsy n_check would make the modulo raise ZeroDivisionError.
        if n_check and it % n_check == 0:
            t2 = datetime.now()
            # timedelta.seconds drops whole days; total_seconds() is the
            # real elapsed time since the previous report.
            elapsed = int((t2 - t1).total_seconds())
            print(f"[APFA GD] iter={it}, loss={loss[it]:.6f}, gradnorm={grad_norm[it]:.6f}, freeze={np.sum(freeze_mask)}, time={elapsed}s")
            t1 = t2
    return loss, grad_norm, weights, freeze_mask

def adam_optimizer_apfa(ansatz, weights, noise_gamma, lr, iteration, n_check, alpha=0.7, freeze_factor=0.5, activate_factor=1.0, freeze_count_th=3, activate_count_th=2, warmup_steps=20):
    """Run noisy Adam with adaptive parameter freezing (APFA).

    Every iteration evaluates ``f_n`` and its gradient, updates a
    per-parameter exponential moving average (EMA) of ``|gradient|`` and asks
    ``apfa_schedule_by_grad`` for the current factors and count thresholds.
    From iteration ``warmup_steps`` onwards an active parameter is frozen
    after its EMA has stayed below the freeze threshold for enough
    consecutive iterations, and a frozen parameter is re-activated after its
    EMA has stayed above the activate threshold for enough consecutive
    iterations. The Adam moments use ``beta_1 = 0.9``, ``beta_2 = 0.99`` and
    ``epsilon = 1e-8``.

    Args:
        ansatz: Callable forwarded to ``f_n``.
        weights: Trainable parameters. Per-parameter state is sized with
            ``len(weights)`` and indexed by a single integer, so a
            one-dimensional array is expected; ``requires_grad`` is set on
            it each iteration.
        noise_gamma: Standard deviation of the zero-mean Gaussian noise added
            to the gradient.
        lr: Learning rate.
        iteration: Number of optimisation steps.
        n_check: Print a progress line every ``n_check`` iterations. A falsy
            value (``0`` or ``None``) disables the progress output.
        alpha: EMA decay applied to the previous average.
        freeze_factor: Base freeze factor handed to the schedule.
        activate_factor: Base activate factor handed to the schedule.
        freeze_count_th: Base number of consecutive low-EMA iterations
            required to freeze a parameter.
        activate_count_th: Base number of consecutive high-EMA iterations
            required to re-activate a parameter.
        warmup_steps: Number of initial iterations during which the freeze
            mask is left untouched.

    Returns:
        Tuple ``(loss, grad_norm, weights)``: the per-iteration loss and
        gradient-norm histories and the final weights. The freeze mask is not
        returned.
    """
    beta_1 = 0.9
    beta_2 = 0.99
    epsilon = 1e-8
    grad_func = qml.grad(f_n)
    n_params = len(weights)
    freeze_mask = np.array([False]*n_params)
    freeze_count = np.zeros(n_params, dtype=int)
    activate_count = np.zeros(n_params, dtype=int)
    m = np.zeros(n_params)
    v = np.zeros(n_params)
    grad_norm = np.zeros(iteration)
    loss = np.zeros(iteration)
    grad_ema = np.zeros(n_params)
    t1 = datetime.now()
    for it in range(iteration):
        weights.requires_grad = True
        loss[it] = f_n(weights, ansatz=ansatz)
        gradient_now = grad_func(weights, ansatz=ansatz)
        grad_norm[it] = la.norm(gradient_now)

        # Element-wise EMA of the gradient magnitude for every parameter.
        grad_ema = alpha * grad_ema + (1 - alpha) * np.abs(gradient_now)

        ff, af, fct, act = apfa_schedule_by_grad(it, grad_norm, freeze_factor, activate_factor, freeze_count_th, activate_count_th)
        global_grad_ema = np.mean(grad_ema)
        freeze_threshold = ff * global_grad_ema
        # The activate threshold is expressed relative to the freeze threshold.
        activate_threshold = af * freeze_threshold

        if it >= warmup_steps:
            for i in range(n_params):
                if freeze_mask[i]:
                    # Frozen: count consecutive iterations above the
                    # activate threshold; any miss resets the streak.
                    if grad_ema[i] > activate_threshold:
                        activate_count[i] += 1
                    else:
                        activate_count[i] = 0
                    if activate_count[i] >= act:
                        freeze_mask[i] = False
                        activate_count[i] = 0
                else:
                    # Active: count consecutive iterations below the freeze
                    # threshold; any miss resets the streak.
                    if grad_ema[i] < freeze_threshold:
                        freeze_count[i] += 1
                    else:
                        freeze_count[i] = 0
                    if freeze_count[i] >= fct:
                        freeze_mask[i] = True
                        freeze_count[i] = 0

        noise = np.random.normal(0, noise_gamma, n_params)
        effective_grad = gradient_now + noise
        effective_grad[freeze_mask] = 0.0

        # Moments keep decaying for frozen parameters (zero gradient input).
        m = beta_1*m + (1-beta_1)*effective_grad
        v = beta_2*v + (1-beta_2)*(effective_grad**2)
        m_hat = m / (1 - beta_1**(it+1))
        v_hat = v / (1 - beta_2**(it+1))

        step = lr*m_hat/(np.sqrt(v_hat) + epsilon)
        # Zeroing the gradient alone is not enough: the residual first moment
        # would keep moving a frozen parameter for many iterations. Mask the
        # step itself so frozen parameters really stay put.
        step[freeze_mask] = 0.0
        weights = weights - step

        # A falsy n_check would make the modulo raise ZeroDivisionError.
        if n_check and it % n_check == 0:
            t2 = datetime.now()
            # timedelta.seconds drops whole days; total_seconds() is the
            # real elapsed time since the previous report.
            elapsed = int((t2 - t1).total_seconds())
            print(f"[APFA Adam] iter={it}, loss={loss[it]:.6f}, gradnorm={grad_norm[it]:.6f}, freeze={np.sum(freeze_mask)}, time={elapsed}s")
            t1 = t2
    return loss, grad_norm, weights